package local

import (
	"bytes"
	"container/list"
	"context"
	"encoding/json"
	"errors"
	"fmt"
	"io/ioutil"
	"net/http"
	"net/url"
	"runtime/debug"
	"sync"
	"time"

	"gitee.com/gitee-go/core"
	"gitee.com/gitee-go/core/bean/oauthBean"
	"gitee.com/gitee-go/core/common"
	"gitee.com/gitee-go/core/model/pipeline"
	comm2 "gitee.com/gitee-go/runner-core/comm"
	"gitee.com/gitee-go/server/comm"
	"gitee.com/gitee-go/server/service"
	"gitee.com/gitee-go/utils/ioex"
)

// RefreshEngine owns the background jobs that refresh users' OAuth access
// tokens and re-sync their repository lists.
type RefreshEngine struct {
	// ctx governs the background loops launched by start: they keep running
	// while ioex.CheckContext(ctx) reports false (presumably until ctx is done).
	ctx     context.Context
	// tms is when run last queried for tokens; run does real work at most every 10s.
	tms     time.Time
	// repoTms is when runRefreshRepo last enqueued all users; at most every 2h.
	repoTms time.Time

	// uidlk guards uids and uidmp.
	uidlk sync.Mutex
	// uids is the FIFO queue of user ids (int64 values) awaiting a repository refresh.
	uids  *list.List
	// uidmp holds per-user refresh state keyed by uid; WaitUid blocks on an entry's Done channel.
	uidmp map[int64]*comm2.RefreshUserRepo
}

// StartRefreshEngine creates a RefreshEngine bound to ctx, with an empty
// refresh queue, launches its background loops via start and returns it.
func StartRefreshEngine(ctx context.Context) *RefreshEngine {
	engine := new(RefreshEngine)
	engine.ctx = ctx
	engine.uids = list.New()
	engine.uidmp = map[int64]*comm2.RefreshUserRepo{}
	engine.start()
	return engine
}
// start launches the engine's two polling loops on separate goroutines; each
// loop runs while ioex.CheckContext(c.ctx) reports false. Nothing is started
// when comm.Debugs is set, so a locally run instance does not refresh data
// in the shared production database.
func (c *RefreshEngine) start() {
	if comm.Debugs {
		return
	}
	// forever calls step, sleeps for pause, and repeats until the context check fires.
	forever := func(pause time.Duration, step func()) {
		for !ioex.CheckContext(c.ctx) {
			step()
			time.Sleep(pause)
		}
	}
	// Token refresh and the periodic "enqueue every user" pass; both callees
	// throttle themselves, so polling once a second is sufficient.
	go forever(time.Second, func() {
		c.run()
		c.runRefreshRepo()
	})
	// Queue drainer: start at most one queued repository refresh per tick,
	// then prune finished entries from uidmp.
	go forever(time.Millisecond, func() {
		c.uidlk.Lock()
		pending := c.uids.Len() > 0
		c.uidlk.Unlock()
		if pending {
			c.runRefreshRepou()
		}
		c.runRefreshRepoCheck()
	})
}

// run refreshes OAuth access tokens via refreshToken. start polls it every
// second, but it does real work at most once every 10 seconds.
//
// A token row is selected when expires_in > 0, its expires_time is more than
// five minutes in the past, and it has not been attempted in the last two
// hours (refresh_time is NULL or older than that).
// NOTE(review): this selects tokens that have already expired, not tokens
// "about to expire" as the original comment said; confirm this is intended.
//
// refresh_time is stamped after every attempt, successful or not, so a
// failing token is retried at most once every two hours.
func (c *RefreshEngine) run() {
	defer func() {
		if err := recover(); err != nil {
			core.LogPnc.Errorf("local.RefreshEngine recover:%+v", err)
			core.LogPnc.Errorf("%s", string(debug.Stack()))
		}
	}()
	if time.Since(c.tms).Seconds() < 10 {
		return
	}
	c.tms = time.Now()

	var ls []*pipeline.TUserToken
	err := comm.DBMain.GetDB().
		Where("expires_in>0 and expires_time<? and (refresh_time is NULL or refresh_time<?)",
			c.tms.Add(-time.Minute*5).Format(common.TimeFmt),
			c.tms.Add(-time.Hour*2).Format(common.TimeFmt)).
		Find(&ls)
	if err != nil {
		core.Log.Errorf("RefreshEngine.run db find err:%v", err)
		return
	}
	// Pass the format and args directly; wrapping them in fmt.Sprintf made
	// the result a non-constant format string.
	core.Log.Debugf("run refresh len:%d", len(ls))
	for _, v := range ls {
		// Stop promptly on shutdown instead of finishing a large batch.
		if ioex.CheckContext(c.ctx) {
			return
		}
		if v.ExpiresIn <= 0 || v.RefreshToken == "" {
			continue
		}
		if err := c.refreshToken(v); err != nil {
			core.Log.Errorf("RefreshEngine.refreshToken %s(%s) err:%v", v.Nick, v.Name, err)
		}
		v.RefreshTime = time.Now()
		if _, uerr := comm.DBMain.GetDB().
			Cols("refresh_time").
			Where("id=?", v.Id).Update(v); uerr != nil {
			core.Log.Errorf("RefreshEngine.run update refresh_time %s(%s) err:%v", v.Nick, v.Name, uerr)
		}
	}
}

// refreshToken refreshes one user's access token through the provider that
// issued it (tk.Type). It fails when the OAuth parameters cannot be loaded,
// when no application is configured for tk.Type, or when tk.Type has no
// refresh implementation (only common.SourceGitee is handled here).
func (c *RefreshEngine) refreshToken(tk *pipeline.TUserToken) error {
	par := &oauthBean.ParamOAuth{}
	if err := service.GetsParam(oauthBean.ParamOAuthKey, par); err != nil {
		return err
	}
	if _, ok := par.AppInfo(tk.Type); !ok {
		return errors.New("application not found")
	}
	if tk.Type == common.SourceGitee {
		return c.retkGitee(tk)
	}
	return errors.New("not found source")
}

// retkGitee exchanges tk.RefreshToken for a new access token at gitee.com.
// It updates tk in place and persists the access_token, refresh_token,
// expires_in, expires_time and tokens columns of the tk row. tokens receives
// the raw JSON response. expires_time is set 100 seconds before the expiry
// reported by gitee, so the token is treated as expired slightly early.
func (c *RefreshEngine) retkGitee(tk *pipeline.TUserToken) error {
	core.Log.Debugf("start refresh %s(%s) access token", tk.Nick, tk.Name)
	// Encode the form so a token containing reserved characters such as
	// '&', '+' or '=' reaches the server intact.
	form := url.Values{}
	form.Set("grant_type", "refresh_token")
	form.Set("refresh_token", tk.RefreshToken)
	// http.Post uses http.DefaultClient, which has no timeout; a stalled
	// endpoint would block the token-refresh loop indefinitely.
	client := &http.Client{Timeout: 30 * time.Second}
	res, err := client.Post("https://gitee.com/oauth/token", "application/x-www-form-urlencoded",
		bytes.NewBufferString(form.Encode()))
	if err != nil {
		return err
	}
	defer res.Body.Close()
	bts, err := ioutil.ReadAll(res.Body)
	if err != nil {
		return err
	}
	if res.StatusCode != 200 {
		return fmt.Errorf("stat(%d) err:%s", res.StatusCode, string(bts))
	}
	info := &oauthBean.TokenInfo{}
	if err := json.Unmarshal(bts, info); err != nil {
		return err
	}
	// A 200 response without an access token must not wipe stored credentials.
	if info.AccessToken == "" {
		return errors.New("refresh response has no access_token")
	}
	info.CreatedTm = time.Now()

	tk.AccessToken = info.AccessToken
	// RFC 6749 section 6: the server MAY issue a new refresh token. If it
	// does not, the existing one remains valid and must be kept.
	if info.RefreshToken != "" {
		tk.RefreshToken = info.RefreshToken
	}
	tk.ExpiresIn = info.ExpiresIn
	tk.ExpiresTime = info.CreatedTm.Add(time.Second*time.Duration(tk.ExpiresIn) - time.Second*100)
	tk.Tokens = string(bts)
	_, err = comm.DBMain.GetDB().
		Cols("access_token", "refresh_token", "expires_in", "expires_time", "tokens").
		Where("id=?", tk.Id).Update(tk)
	return err
}

// retkGiteePre refreshes a Gitee Premium (self-hosted) access token. It works
// like retkGitee, but posts to oapp.SourceHost + "/oauth/token". It updates
// tk in place and persists the access_token, refresh_token, expires_in,
// expires_time and tokens columns. expires_time is set 100 seconds before the
// reported expiry.
func (c *RefreshEngine) retkGiteePre(tk *pipeline.TUserToken, oapp *oauthBean.AppInfo) error {
	if oapp == nil {
		return errors.New("application info is nil")
	}
	core.Log.Debugf("start refresh %s(%s) access token", tk.Nick, tk.Name)
	// Encode the form so a token containing reserved characters such as
	// '&', '+' or '=' reaches the server intact.
	form := url.Values{}
	form.Set("grant_type", "refresh_token")
	form.Set("refresh_token", tk.RefreshToken)
	// http.Post uses http.DefaultClient, which has no timeout; a stalled
	// endpoint would block the token-refresh loop indefinitely.
	client := &http.Client{Timeout: 30 * time.Second}
	res, err := client.Post(oapp.SourceHost+"/oauth/token", "application/x-www-form-urlencoded",
		bytes.NewBufferString(form.Encode()))
	if err != nil {
		return err
	}
	defer res.Body.Close()
	bts, err := ioutil.ReadAll(res.Body)
	if err != nil {
		return err
	}
	if res.StatusCode != 200 {
		return fmt.Errorf("stat(%d) err:%s", res.StatusCode, string(bts))
	}
	info := &oauthBean.TokenInfo{}
	if err := json.Unmarshal(bts, info); err != nil {
		return err
	}
	// A 200 response without an access token must not wipe stored credentials.
	if info.AccessToken == "" {
		return errors.New("refresh response has no access_token")
	}
	info.CreatedTm = time.Now()

	tk.AccessToken = info.AccessToken
	// RFC 6749 section 6: the server MAY issue a new refresh token. If it
	// does not, the existing one remains valid and must be kept.
	if info.RefreshToken != "" {
		tk.RefreshToken = info.RefreshToken
	}
	tk.ExpiresIn = info.ExpiresIn
	tk.ExpiresTime = info.CreatedTm.Add(time.Second*time.Duration(tk.ExpiresIn) - time.Second*100)
	tk.Tokens = string(bts)
	_, err = comm.DBMain.GetDB().
		Cols("access_token", "refresh_token", "expires_in", "expires_time", "tokens").
		Where("id=?", tk.Id).Update(tk)
	return err
}

// runRefreshRepoCheck prunes finished entries from uidmp. An entry is removed
// once its refresh has completed (Done closed) and either a WaitUid caller
// has consumed it (Del) or it finished more than 30 seconds ago.
// Entries that are still queued or running are left alone.
func (c *RefreshEngine) runRefreshRepoCheck() {
	c.uidlk.Lock()
	defer c.uidlk.Unlock()
	for uid, rup := range c.uidmp {
		select {
		case <-rup.Done:
		default:
			continue // still queued or running
		}
		// Deleting during range is safe in Go. Keep scanning so every stale
		// entry is pruned in one pass, not just the first one found.
		if rup.Del || time.Since(rup.Times).Seconds() > 30 {
			delete(c.uidmp, uid)
		}
	}
}

// runRefreshRepou pops the next user id from the queue. If that uid still has
// a pending entry in uidmp, it refreshes the user's repositories on a new
// goroutine. When that work ends, the entry's Times is set and its Done
// channel is closed, which releases WaitUid callers.
//
// Done is also closed when the user has no token row, when loading the token
// fails, or when a panic occurs before the goroutine starts. Otherwise those
// paths would leave waiters blocked forever and the entry never pruned.
func (c *RefreshEngine) runRefreshRepou() {
	defer func() {
		if err := recover(); err != nil {
			core.LogPnc.Errorf("local.RefreshEngine recover:%+v", err)
			core.LogPnc.Errorf("%s", string(debug.Stack()))
		}
	}()

	// Pop the head of the queue in a single critical section.
	c.uidlk.Lock()
	e := c.uids.Front()
	if e != nil {
		c.uids.Remove(e)
	}
	c.uidlk.Unlock()
	if e == nil {
		return
	}
	uid := e.Value.(int64)

	c.uidlk.Lock()
	rup, ok := c.uidmp[uid]
	if ok {
		select {
		case <-rup.Done:
			// Stale duplicate queue entry: this refresh already finished, and
			// closing Done a second time would panic.
			ok = false
		default:
		}
	}
	c.uidlk.Unlock()
	if !ok {
		return
	}

	// Until the goroutine below takes over, this call owns rup and must close
	// Done on every exit path, panics included.
	handedOff := false
	defer func() {
		if !handedOff {
			rup.Times = time.Now()
			close(rup.Done)
		}
	}()

	tk := &pipeline.TUserToken{}
	found, err := comm.DBMain.GetDB().Where("uid=?", uid).Get(tk)
	if err != nil {
		core.Log.Errorf("RefreshEngine.runRefreshRepou load token uid:%d err:%v", uid, err)
		return
	}
	if !found {
		core.Log.Errorf("not found user uid:%d", uid)
		return
	}

	handedOff = true
	go func() {
		// The recover at the top of runRefreshRepou does not cover this
		// goroutine; without this one, any panic here (e.g. a double close)
		// would crash the process.
		defer func() {
			if err := recover(); err != nil {
				core.LogPnc.Errorf("local.RefreshEngine recover:%+v", err)
				core.LogPnc.Errorf("%s", string(debug.Stack()))
			}
		}()
		defer func() {
			rup.Times = time.Now()
			close(rup.Done)
		}()
		if err := c.runRefreshRepos(tk); err != nil {
			core.Log.Errorf("RefreshEngine.runRefreshRepou err:%v", err)
		}
	}()
}
// runRefreshRepo enqueues, via PutUid, the uid of every row in the user-token
// table for a repository refresh. start polls it every second, but it does
// real work at most once every two hours. The refreshes themselves run later
// from the queue, so this pass only reads the table in pages of 1000 rows.
func (c *RefreshEngine) runRefreshRepo() {
	defer func() {
		if err := recover(); err != nil {
			core.LogPnc.Errorf("local.RefreshEngine recover:%+v", err)
			core.LogPnc.Errorf("%s", string(debug.Stack()))
		}
	}()
	if time.Since(c.repoTms).Hours() < 2 {
		return
	}
	c.repoTms = time.Now()
	db := comm.DBMain.GetDB()
	count, err := db.Table(&pipeline.TUserToken{}).Count()
	if err != nil {
		core.Log.Errorf("RefreshEngine.runRefreshRepo db find err:%v", err)
		return
	}
	if count <= 0 {
		return
	}
	const pageSize = 1000
	pages := int((count + pageSize - 1) / pageSize) // ceil(count / pageSize)
	ts := time.Now()
	// Enqueue asynchronously and let the queue loop do the actual refreshes.
	for i := 0; i < pages && !ioex.CheckContext(c.ctx); i++ {
		var ls []*pipeline.TUserToken
		// ORDER BY id keeps LIMIT/OFFSET pages disjoint. Without it the database
		// may return rows in any order, so users can be skipped or enqueued twice.
		err = db.OrderBy("id").Limit(pageSize, i*pageSize).Find(&ls)
		if err != nil {
			core.Log.Errorf("RefreshEngine.runRefreshRepo db find err:%v", err)
			return
		}
		for _, tu := range ls {
			c.PutUid(tu.Uid)
		}
	}
	core.Log.Debugf("RefreshEngine RefreshRepo 刷新仓库 end time:%.4fs", time.Since(ts).Seconds())
}

// 实际刷新用户仓库逻辑实现
func (c *RefreshEngine) runRefreshRepos(tu *pipeline.TUserToken) error {
	ts := time.Now()
	defer func() {
		core.Log.Debugf("RefreshEngine runRefreshRepos(%s:%s) end time:%.4fs",
			tu.Name, tu.Nick, time.Since(ts).Seconds())
		if err := recover(); err != nil {
			core.LogPnc.Errorf("local.RefreshEngine runRefreshRepos recover:%+v", err)
			core.LogPnc.Errorf("%s", string(debug.Stack()))
		}
	}()
	core.Log.Debugf("刷新仓库中 user:%v", tu.Nick)
	if tu.AccessToken == "" || tu.Uid == 0 {
		return errors.New("刷新仓库 - 用户未授权")
	}
	appinfo, err := service.GetsParamOAuthKey()
	if err != nil {
		core.Log.Errorf("RefreshRepos.GetsParamOAuthKey err : %v", err)
		return errors.New("刷新仓库失败,OAuthKey错误")
	}
	k, info := appinfo.DefAppInfo()
	cl, err := comm.GetThirdApi(k, info.SourceHost)
	if err != nil {
		core.Log.Errorf("runRefreshRepos.GetThirdApi err : %v", err)
		return errors.New("刷新仓库失败,API错误")
	}
	pages, err := cl.Repositories.GetRepos(tu.AccessToken, tu.Name, common.All, common.Pushed, common.OrderDesc, 1, 20)
	if err != nil {
		return errors.New("刷新仓库失败,请检查用户权限")
	}
	if pages.TotalPages <= 0 {
		return errors.New("刷新仓库失败,用户没有仓库")
	}
	db := comm.DBMain.GetDB()
	tur := &pipeline.TUserRepo{}
	_, err = db.Where("user_id = ?", tu.Uid).Delete(tur)
	if err != nil {
		core.Log.Errorf("RefreshRepos.GetRepos TUserRepo Delete db err : %v", err)
		return fmt.Errorf("刷新仓库失败,请重试")
	}
	core.Log.Debugf("刷新仓库中 user:%v 总页数%v", tu.Nick, pages.TotalPages)

	wg := &sync.WaitGroup{}
	for i := 0; i < int(pages.TotalPages); i++ {
		wg.Add(1)
		go func(pg int) {
			defer wg.Done()
			err := service.RefreshRepo(tu, pg)
			if err != nil {
				core.Log.Errorf("service.RefreshRepo err:%v", err)
			}
		}(i + 1)
	}
	wg.Wait()
	return nil
}

// PutUid queues uid for a repository refresh; WaitUid can then block until
// that refresh finishes. Negative ids are ignored.
//
// If uid already has an entry whose Done channel is still open (queued or
// running), that entry is reused and the id is not queued again. Queueing a
// duplicate made two queue entries resolve to the same RefreshUserRepo. Its
// Done channel was then closed twice (a panic), and waiters on the replaced
// entry were never released. Trade-off: a caller who arrives while a refresh
// is already running waits for that refresh rather than a new one.
func (c *RefreshEngine) PutUid(uid int64) {
	if uid < 0 {
		return
	}
	c.uidlk.Lock()
	defer c.uidlk.Unlock()
	if pending, ok := c.uidmp[uid]; ok {
		select {
		case <-pending.Done:
			// The previous refresh finished; queue a fresh one below.
		default:
			return // already queued or running
		}
	}
	rup := comm2.NewRefreshUserRepo()
	rup.Uid = uid
	c.uids.PushBack(rup.Uid)
	c.uidmp[rup.Uid] = rup
}

// WaitUid blocks until the refresh queued for uid by PutUid has finished (its
// Done channel is closed). It then marks the entry as consumed (Del) so that
// runRefreshRepoCheck may drop it. If uid has no entry, it logs an error and
// returns at once. It also gives up when the engine's context is done.
// NOTE(review): when comm.Debugs is set the worker loops never start, so a
// waiter only returns once the context is done.
func (c *RefreshEngine) WaitUid(uid int64) {
	c.uidlk.Lock()
	rup, ok := c.uidmp[uid]
	c.uidlk.Unlock()
	if !ok {
		core.Log.Errorf("RefreshEngine.WaitUid not found uid:%d", uid)
		return
	}
	// A nil channel is never ready, so a nil ctx (or one that can never be
	// cancelled) simply means "wait for Done".
	var stop <-chan struct{}
	if c.ctx != nil {
		stop = c.ctx.Done()
	}
	select {
	case <-rup.Done:
	case <-stop:
		return
	}
	// runRefreshRepoCheck reads Del while holding uidlk; writing it without
	// the lock was a data race.
	c.uidlk.Lock()
	rup.Del = true
	c.uidlk.Unlock()
}
